import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, QuantileTransformer, PowerTransformer
from sklearn.neural_network import MLPRegressor
from sklearn.compose import TransformedTargetRegressor
from sklearn.metrics import mean_squared_error, PredictionErrorDisplay, median_absolute_error, r2_score, mean_absolute_error
import scipy.stats as stats
from scipy.interpolate import CubicHermiteSpline, PchipInterpolator
from scipy.stats import norm
from statsmodels.api import GLM, families
from statsmodels.tools.sm_exceptions import PerfectSeparationError

class LinearExtrapolator:
    """Wrap a function so it continues as a straight line outside an interval.

    Inputs inside ``in_bounds`` are passed to ``func`` (or, when inverting,
    outputs between ``func(lower)`` and ``func(upper)`` are passed to
    ``inverse_func``). Everything outside is mapped along a line that starts
    at the boundary value and has the configured slope, so the forward and
    inverse mappings agree with each other outside the interval as well.

    NOTE: the inverse branch compares against ``func(lower)`` and
    ``func(upper)`` as lower/upper output bounds, so it presumably expects
    ``func`` to be increasing on the interval.
    """

    def __init__(self, func, inverse_func, in_bounds, lower_slope, upper_slope):
        """Store the wrapped functions, the interval and the outer slopes.

        Args:
            func: Forward mapping, called with an array of in-bounds inputs.
            inverse_func: Inverse of ``func``, called with an array of
                outputs lying between ``func(lower)`` and ``func(upper)``.
            in_bounds: ``(lower, upper)`` input interval handled by ``func``.
            lower_slope: Slope of the linear continuation below ``lower``.
            upper_slope: Slope of the linear continuation above ``upper``.
        """
        self.f = func
        self.i_f = inverse_func
        self.il, self.iu = in_bounds
        # Output values at the interval ends; the linear pieces attach here.
        self.ol, self.ou = self.f(self.il), self.f(self.iu)
        self.ls = lower_slope
        self.us = upper_slope

    @staticmethod
    def _as_float(x):
        """Return ``x`` as a floating-point array (float dtypes are kept)."""
        x = np.asarray(x)
        if not np.issubdtype(x.dtype, np.floating):
            # An integer output buffer would silently truncate the results.
            x = x.astype(np.float64)
        return x

    def __call__(self, x: np.ndarray, inverse=False):
        """Apply the extended mapping (or its inverse) element-wise.

        Args:
            x: Values to map. Non-float input is converted to float64.
            inverse: If True, map outputs back to inputs.

        Returns:
            Array with the same shape as ``x``. NaN entries stay NaN.
        """
        x = self._as_float(x)
        # Start from NaN so entries matching no region (NaN inputs) are not
        # reported as a valid-looking 0.0.
        out = np.full_like(x, np.nan)

        if not inverse:
            inside = (x >= self.il) & (x <= self.iu)
            below = x < self.il
            above = x > self.iu
            out[inside] = self.f(x[inside])
            out[below] = self.ol - self.ls * (self.il - x[below])
            out[above] = self.ou + self.us * (x[above] - self.iu)
        else:
            inside = (x >= self.ol) & (x <= self.ou)
            below = x < self.ol
            above = x > self.ou
            out[inside] = self.i_f(x[inside])
            out[below] = self.il - (self.ol - x[below]) / self.ls
            out[above] = self.iu + (x[above] - self.ou) / self.us
        return out

class ContinuousExtrapolator:
    """Extend a function beyond an interval with a smooth change of slope.

    Inside ``in_bounds`` the wrapped ``func`` / ``inverse_func`` is used.
    Outside, at distance ``d`` from the boundary, the output moves away from
    the boundary value by::

        tail(d) = target_slope * d + scl - a / (d + b)

    with ``b = scl / (f_slope - target_slope)`` and ``a = scl * b``. This
    gives ``tail(0) = 0`` and ``tail'(0) = f_slope`` (value and slope match
    at the boundary), while the slope tends to ``target_slope`` far away.
    ``scl`` is ``+scale`` or ``-scale``, chosen so that ``b`` is positive.
    When the target slope equals the boundary slope the tail is simply the
    straight line ``f_slope * d``.

    NOTE: the inverse divides by the target slope (or by the boundary slope
    in the linear case), so those are presumably expected to be non-zero.
    """

    def __init__(
            self, func, inverse_func, in_bounds,
            lower_f_slope, upper_f_slope,
            lower_target_slope=None, upper_target_slope=None,
            scale=1.0
    ):
        """Store the wrapped functions and precompute the tail coefficients.

        Args:
            func: Forward mapping, called with an array of in-bounds inputs.
            inverse_func: Inverse of ``func`` on the in-bounds output range.
            in_bounds: ``(lower, upper)`` input interval handled by ``func``.
            lower_f_slope: Slope of ``func`` at ``lower``.
            upper_f_slope: Slope of ``func`` at ``upper``.
            lower_target_slope: Asymptotic slope below ``lower``; defaults to
                ``lower_f_slope`` (purely linear tail).
            upper_target_slope: Asymptotic slope above ``upper``; defaults to
                ``upper_f_slope`` (purely linear tail).
            scale: Magnitude of ``scl`` in the tail formula.
        """
        self.f, self.i_f = func, inverse_func
        self.il, self.iu = in_bounds
        self.ol, self.ou = self.f(self.il), self.f(self.iu)
        self.lfs, self.ufs = lower_f_slope, upper_f_slope
        self.lts = lower_target_slope if lower_target_slope is not None else lower_f_slope
        self.uts = upper_target_slope if upper_target_slope is not None else upper_f_slope
        if self.lts == self.lfs:
            self.l_lin = True
        else:
            # The sign of scl keeps l_b = scl / (lfs - lts) positive, so the
            # term a / (d + b) has no pole for d >= 0.
            self.l_scl = scale if self.lts < self.lfs else -scale
            self.l_lin = False
            self.l_b = self.l_scl / (self.lfs - self.lts)
            self.l_a = self.l_scl * self.l_b
        if self.uts == self.ufs:
            self.u_lin = True
        else:
            self.u_scl = scale if self.uts < self.ufs else -scale
            self.u_lin = False
            self.u_b = self.u_scl / (self.ufs - self.uts)
            self.u_a = self.u_scl * self.u_b

    @staticmethod
    def _as_float(x):
        """Return ``x`` as a floating-point array (float dtypes are kept)."""
        x = np.asarray(x)
        if not np.issubdtype(x.dtype, np.floating):
            # An integer output buffer would silently truncate the results.
            x = x.astype(np.float64)
        return x

    @staticmethod
    def _tail(dist, target_slope, scl, a, b):
        """Output offset from the boundary value at input distance ``dist``."""
        return target_slope * dist + scl - a / (dist + b)

    @staticmethod
    def _tail_inverse(offset, target_slope, scl, a, b):
        """Input distance from the boundary for a given output ``offset``.

        Solves ``_tail(d) = offset`` for ``d`` via the quadratic in
        ``d + b``, taking the root with the positive square root.
        """
        interm = offset + target_slope * b - scl
        return (np.sqrt(np.square(interm) + 4 * a * target_slope) + interm) / (
                2 * target_slope) - b

    def __call__(self, x: np.ndarray, inverse=False):
        """Apply the extended mapping (or its inverse) element-wise.

        Args:
            x: Values to map. Non-float input is converted to float64.
            inverse: If True, map outputs back to inputs.

        Returns:
            Array with the same shape as ``x``. NaN entries stay NaN.
        """
        x = self._as_float(x)
        # Start from NaN so entries matching no region (NaN inputs) are not
        # reported as a valid-looking 0.0.
        out = np.full_like(x, np.nan)

        if not inverse:
            inside = (x >= self.il) & (x <= self.iu)
            below, above = x < self.il, x > self.iu
            out[inside] = self.f(x[inside])
            lower_d, upper_d = self.il - x[below], x[above] - self.iu
            if self.l_lin:
                out[below] = self.ol - self.lfs * lower_d
            else:
                out[below] = self.ol - self._tail(
                    lower_d, self.lts, self.l_scl, self.l_a, self.l_b)
            if self.u_lin:
                out[above] = self.ou + self.ufs * upper_d
            else:
                out[above] = self.ou + self._tail(
                    upper_d, self.uts, self.u_scl, self.u_a, self.u_b)
            return out

        inside = (x >= self.ol) & (x <= self.ou)
        below, above = x < self.ol, x > self.ou
        out[inside] = self.i_f(x[inside])
        lower_off, upper_off = self.ol - x[below], x[above] - self.ou
        if self.l_lin:
            out[below] = self.il - lower_off / self.lfs
        else:
            out[below] = self.il - self._tail_inverse(
                lower_off, self.lts, self.l_scl, self.l_a, self.l_b)
        if self.u_lin:
            out[above] = self.iu + upper_off / self.ufs
        else:
            out[above] = self.iu + self._tail_inverse(
                upper_off, self.uts, self.u_scl, self.u_a, self.u_b)
        return out

class OrderedQuantileTransformer:
    """Rank-based (ordered quantile) normalizing transformation.

    Every non-missing fitted value is mapped to ``norm.ppf((rank - 0.5) / n)``
    where tied values share their average rank. New values inside the fitted
    range are linearly interpolated between the fitted (value, score) pairs.
    Values outside that range are approximated with a binomial GLM fitted to
    the rank quantiles, when that fit is available; otherwise they stay NaN.

    Attributes set by fitting:
        x: Fitted data as a flat float64 array (may contain NaN).
        x_t: Normal scores aligned with ``x`` (NaN where ``x`` is NaN).
        ties_status: 1 if the non-missing data contained ties, else 0.
        n_logit_fit: Number of points requested for the extrapolation fit.
        fit_res: Fitted GLM results, or None if fitting was skipped.
    """

    def __init__(self, warn = True):
        """
        Args:
            warn: If True, print warnings about ties and about values that
                fall outside the fitted range.
        """
        self.warn = warn

    def predict(self, newdata=None, inverse=False):
        """Transform (or inverse-transform) a flat array of values.

        Args:
            newdata: Values to convert. Defaults to the fitted data (or the
                fitted scores when ``inverse`` is True).
            inverse: If True, map normal scores back to the data scale.

        Returns:
            A new float64 array; NaN entries are passed through unchanged.
        """
        if newdata is None:
            newdata = self.x_t if inverse else self.x

        # np.array always copies: writing the results in place must neither
        # modify the caller's array nor the fitted self.x / self.x_t.
        values = np.array(newdata, dtype=np.float64)
        present = ~np.isnan(values)

        convert = self._inv_order_norm_trans if inverse else self._order_norm_trans
        values[present] = convert(values[present])
        return values

    @staticmethod
    def _rank_scores(x):
        """Return normal scores aligned with ``x`` (NaN where ``x`` is NaN)."""
        x = np.asarray(x, dtype=np.float64)
        observed = ~np.isnan(x)
        scores = np.full(x.shape, np.nan)
        n_obs = int(observed.sum())
        if n_obs:
            # rankdata is 1-based and averages the ranks of tied values, so
            # equal inputs receive equal scores.
            ranks = stats.rankdata(x[observed])
            scores[observed] = norm.ppf((ranks - 0.5) / n_obs)
        return scores

    def _knots(self):
        """Return the interpolation table as strictly increasing arrays."""
        observed = ~np.isnan(self.x)
        # np.interp needs increasing abscissae, so keep one knot per distinct
        # fitted value (ties share a score, any representative will do).
        x_knots, first_idx = np.unique(self.x[observed], return_index=True)
        return x_knots, self.x_t[observed][first_idx]

    def _order_norm_trans(self, new_points):
        """Map data-scale values to normal scores."""
        new_points = np.asarray(new_points, dtype=np.float64)
        x_knots, t_knots = self._knots()
        approx_values = np.interp(new_points, x_knots, t_knots, left=np.nan, right=np.nan)

        # Extrapolation if necessary
        outside = np.isnan(approx_values)
        if outside.any():
            if self.warn:
                print('Warning: Transformations requested outside observed domain; logit approx. on ranks applied')

            if self.fit_res is not None:
                design = np.column_stack([np.ones(int(outside.sum())), new_points[outside]])
                approx_values[outside] = norm.ppf(self.fit_res.predict(design))

        return approx_values

    def _inv_order_norm_trans(self, new_points_x_t):
        """Map normal scores back to the data scale."""
        new_points_x_t = np.asarray(new_points_x_t, dtype=np.float64)
        x_knots, t_knots = self._knots()
        approx_values = np.interp(new_points_x_t, t_knots, x_knots, left=np.nan, right=np.nan)

        # Extrapolation if necessary
        outside = np.isnan(approx_values)
        if outside.any():
            if self.warn:
                print('Warning: Transformations requested outside observed domain; logit approx. on ranks applied')

            if self.fit_res is not None:
                # Algebraic inverse of the forward extrapolation
                #   t = ppf(sigmoid(b0 + b1 * x))
                # i.e. x = (logit(cdf(t)) - b0) / b1. The logit is taken as
                # logcdf - logsf to stay accurate far out in the tails.
                scores = new_points_x_t[outside]
                logits = norm.logcdf(scores) - norm.logsf(scores)
                intercept, slope = self.fit_res.params[0], self.fit_res.params[1]
                approx_values[outside] = (logits - intercept) / slope

        return approx_values

    def __repr__(self):
        x_obs = self.x[~np.isnan(self.x)]
        if self.ties_status == 1:
            tie_text = f'ties\n - {len(np.unique(x_obs))} unique values'
        else:
            tie_text = 'no ties'
        quartiles = np.round(np.quantile(x_obs, [0.25, 0.5, 0.75]), 3)
        return (
            f'OrderNorm Transformation with {len(x_obs)} nonmissing obs and '
            f'{tie_text}\n- Original quantiles:\n{quartiles}'
        )

    def _fit(self, x, n_logit_fit=None):
        """Fit the rank transformation and the extrapolation model.

        Args:
            x: Flat array-like of training values; NaN entries are ignored.
            n_logit_fit: Number of evenly spaced order statistics used to fit
                the extrapolation GLM (default: at most 10000).

        Returns:
            self

        Raises:
            ValueError: If ``x`` has no non-missing values.
        """
        self.x = np.array(x, dtype=np.float64)
        self.ties_status = 0

        # Handle missing values
        observed = ~np.isnan(self.x)
        x_non_na = self.x[observed]
        n_obs = x_non_na.size
        if n_obs == 0:
            raise ValueError('OrderedQuantileTransformer needs at least one non-missing value')

        self.n_logit_fit = min(n_obs, 10000) if n_logit_fit is None else n_logit_fit

        # Check for ties
        if np.unique(x_non_na).size < n_obs:
            if self.warn:
                print('Warning: Ties in data, Normal distribution not guaranteed')
            self.ties_status = 1

        # Perform the quantile normalization
        self.x_t = self._rank_scores(self.x)

        # Fit the model for extrapolation on a thinned, sorted subsample.
        # Never request more points than there are observations.
        n_keep = min(n_obs, self.n_logit_fit)
        keep_idx = np.round(np.linspace(0, n_obs - 1, n_keep)).astype(int)
        x_red = np.sort(x_non_na)[keep_idx]
        q_red = (stats.rankdata(x_red) - 0.5) / x_red.size
        design = np.column_stack([np.ones(x_red.size), x_red])

        try:
            self.fit_res = GLM(q_red, design, family=families.Binomial()).fit()
        except PerfectSeparationError:
            print("Perfect separation detected, skipping model fitting.")
            self.fit_res = None
        return self

    def fit(self, x):
        """Fit on an array of any shape (it is flattened first)."""
        return self._fit(np.asarray(x).flatten())

    def transform(self, x):
        """Transform ``x`` to normal scores, preserving its shape."""
        x = np.asarray(x)
        return self.predict(newdata=x.ravel(), inverse=False).reshape(x.shape)

    def fit_transform(self, x):
        """Fit on ``x`` and return its transformed values."""
        self.fit(x)
        return self.transform(x)

    def inverse_transform(self, x):
        """Map normal scores back to the data scale, preserving shape."""
        x = np.asarray(x)
        return self.predict(x.ravel(), inverse=True).reshape(x.shape)

    def __sklearn_clone__(self):
        # A clone must be a fresh, unfitted estimator; returning self would
        # let every "copy" share (and overwrite) the same fitted state.
        return type(self)(warn=self.warn)

class SQNNormalizer:
    """Spline-based quantile normalizer for a single 1-D feature.

    Fitting estimates the data's CDF with a Gaussian KDE, samples it at
    ``s + 1`` evenly spaced points between the data minimum and maximum, and
    interpolates those samples with a cubic Hermite spline using PCHIP
    derivatives. Transforming pushes values through that spline and then
    through the standard normal quantile function. Both stages are wrapped
    in ``ContinuousExtrapolator`` so they stay defined and invertible
    outside the fitted range.
    """

    def __init__(self, s = 16, sigma = 1.0):
        """
        Args:
            s: Number of spline segments (coerced to int).
            sigma: Bandwidth argument forwarded to ``gaussian_kde``.
        """
        self.s = int(s)
        self.sigma = sigma

    def fit(self, X: np.ndarray):
        """Fit the KDE-based spline and both extrapolators; returns self."""
        self.X_min, self.X_max = X.min(), X.max()
        X_range = self.X_max - self.X_min

        # Smooth CDF estimate: integral of the KDE from -inf up to a point.
        kde = stats.gaussian_kde(X, bw_method=self.sigma)

        def kde_cdf(upper):
            return kde.integrate_box_1d(-np.inf, upper)

        self.cdf = np.vectorize(kde_cdf, otypes=[np.float64])

        knots_x = np.linspace(self.X_min, self.X_max, self.s + 1)
        knots_y = self.cdf(knots_x)

        # NOTE: _find_derivatives is a private scipy helper; it supplies the
        # PCHIP knot derivatives used to build the Hermite spline below.
        knots_x_shaped = knots_x.reshape((knots_x.shape[0],) + (1,)*(knots_y.ndim-1))
        knot_slopes = PchipInterpolator._find_derivatives(x=knots_x_shaped, y=knots_y)

        self.spline = CubicHermiteSpline(knots_x, knots_y, knot_slopes, extrapolate=False)

        # Stage 1: data -> CDF value. Outside [X_min, X_max] the slope
        # relaxes from the spline's end slope towards 1 / range.
        self.spline_ext = ContinuousExtrapolator(
            func=self.spline,
            inverse_func=np.vectorize(self.__inverse_spline, otypes=[np.float64]),
            in_bounds=[self.X_min, self.X_max],
            lower_f_slope=knot_slopes[0],
            upper_f_slope=knot_slopes[-1],
            lower_target_slope=1/X_range,
            upper_target_slope=1/X_range,
            scale=1
        )

        cdf_min, cdf_max = self.spline(X.min()), self.spline(X.max())
        # Derivative of norm.ppf at p is sqrt(2*pi) * exp(ppf(p)**2 / 2).
        root_two_pi = np.sqrt(2 * np.pi)
        ppf_slope_l = root_two_pi * np.exp(stats.norm.ppf(cdf_min) ** 2 / 2)
        ppf_slope_u = root_two_pi * np.exp(stats.norm.ppf(cdf_max) ** 2 / 2)

        # Stage 2: CDF value -> normal score. Outside [cdf_min, cdf_max] the
        # slope relaxes towards sqrt(2*pi).
        self.ppf_ext = ContinuousExtrapolator(
            func=stats.norm.ppf,
            inverse_func=stats.norm.cdf,
            in_bounds=[cdf_min, cdf_max],
            lower_f_slope=ppf_slope_l,
            upper_f_slope=ppf_slope_u,
            lower_target_slope=root_two_pi,
            upper_target_slope=root_two_pi,
            scale=1
        )
        return self

    def transform(self, X: np.ndarray):
        """Map data values to normal scores (spline stage, then ppf stage)."""
        return self.ppf_ext(self.spline_ext(X))

    def __inverse_spline(self, y_value):
        """Return the first root of ``spline(x) == y_value`` in the fitted range."""
        return self.spline.solve(y_value, discontinuity=False, extrapolate=False)[0]

    def inverse_transform(self, X):
        """Map normal scores back to data values (ppf stage, then spline stage)."""
        cdf_values = self.ppf_ext(X, inverse=True)
        return self.spline_ext(cdf_values, inverse=True)

    def fit_transform(self, X):
        """Fit on ``X`` and return its transformed values."""
        self.fit(X)
        return self.transform(X)

    def __sklearn_clone__(self):
        """Return an unfitted normalizer with the same ``s`` and ``sigma``."""
        return SQNNormalizer(self.s, self.sigma)

class SQNTransformer:
    """Apply ``SQNNormalizer`` to a 1-D array or to each column of a 2-D array.

    Whether the input was 1-D is decided once, in ``fit``, and reused by
    ``transform`` and ``inverse_transform``.
    """

    def __init__(self, s = 16, sigma = 0.1):
        """
        Args:
            s: Forwarded to every ``SQNNormalizer`` as its ``s`` argument.
            sigma: Forwarded to every ``SQNNormalizer`` as its ``sigma``.
        """
        self.s = s
        self.sigma = sigma

    def fit(self, X):
        """Fit one normalizer (1-D input) or one per column; returns self."""
        self.one_d = X.ndim == 1
        if self.one_d:
            self.normalizer = SQNNormalizer(self.s, sigma=self.sigma)
            self.normalizer.fit(X)
            return self

        self.n_features = X.shape[-1]
        fitted = []
        for col in range(self.n_features):
            fitted.append(SQNNormalizer(self.s, sigma=self.sigma).fit(X[:, col]))
        self.normalizers = fitted
        return self

    def transform(self, X):
        """Normalize ``X`` with the fitted normalizer(s)."""
        if self.one_d:
            return self.normalizer.transform(X)

        columns = [
            fitted.transform(X[:, col])
            for col, fitted in enumerate(self.normalizers)
        ]
        # Re-assemble the per-column results along the last axis.
        return np.stack(columns, axis=-1)

    def inverse_transform(self, X):
        """Undo ``transform`` with the fitted normalizer(s)."""
        if self.one_d:
            return self.normalizer.inverse_transform(X)

        columns = [
            fitted.inverse_transform(X[:, col])
            for col, fitted in enumerate(self.normalizers)
        ]
        return np.stack(columns, axis=-1)

    def fit_transform(self, X):
        """Fit on ``X`` and return its transformed values."""
        self.fit(X)
        return self.transform(X)

    def __sklearn_clone__(self):
        """Return an unfitted transformer with the same ``s`` and ``sigma``."""
        return SQNTransformer(self.s, self.sigma)

class IdentityTransformer:
    """Transformer that returns its input unchanged.

    It holds no state, so fitting is a no-op and cloning returns the same
    object.
    """

    def __init__(self):
        """Nothing to configure."""

    def fit(self, X):
        """Ignore ``X`` and return self."""
        return self

    def transform(self, X):
        """Return ``X`` itself."""
        return X

    def inverse_transform(self, X):
        """Return ``X`` itself."""
        return X

    def fit_transform(self, X):
        """Return ``X`` itself (fitting does nothing)."""
        return self.fit(X).transform(X)

    def __sklearn_clone__(self):
        """Return this same instance."""
        return self

# Benchmark: train the same MLP on the California housing data with different
# target transformations and compare the test-set prediction errors.

# 1. Load the dataset
california_housing = fetch_california_housing()
X, y = california_housing.data, california_housing.target

# 2. Preprocess the data (split and scale)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=43)

scaler_X = StandardScaler()
X_train_scaled = scaler_X.fit_transform(X_train)
X_test_scaled = scaler_X.transform(X_test)

# Target transformers under comparison, each with the label used in its plot title.
transformers = [
    (StandardScaler(), "STD"),
    (QuantileTransformer(output_distribution='normal'), "CQN"),
    (OrderedQuantileTransformer(warn=False), "OQN"),
    (PowerTransformer(method='box-cox'), "BXC"),
    (PowerTransformer(method='yeo-johnson'), "YJN"),
    (SQNTransformer(s=16, sigma=0.2), "SQN (ours)"),
]

cols = 3
rows = int(np.ceil(len(transformers) / cols))

# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so the (row, col) indexing below works for any number of transformers.
fig, axes = plt.subplots(rows, cols, figsize=(18, 12), squeeze=False)

MODEL_SHAPE = (90, 45) # [Hidden layers only]
# Same seed for every run so differences between panels come from the target
# transformation and not from the network's random initialisation.
MODEL_SEED = 43

for i, (transformer, name) in enumerate(transformers):

    # 3. Define the neural network model
    mlp = MLPRegressor(
        hidden_layer_sizes=MODEL_SHAPE, activation='relu', solver='adam',
        max_iter=1000, random_state=MODEL_SEED,
    )

    # 4. Wrap the model with TransformedTargetRegressor
    model = TransformedTargetRegressor(regressor=mlp, transformer=transformer)

    # 5. Train the model
    model.fit(X_train_scaled, y_train)

    # 6. Evaluate the model on the held-out data
    y_pred_test = model.predict(X_test_scaled)

    test_rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))
    test_mdae = median_absolute_error(y_test, y_pred_test)
    test_mae = mean_absolute_error(y_test, y_pred_test)

    score_label = f"RMSE: {test_rmse:.3f}\nMAE: {test_mae:.3f}\nMdAE: {test_mdae:.3f}"

    # Row/column derived from `cols` so the grid layout has a single source of truth.
    ax = axes[divmod(i, cols)]
    # 7. Plot prediction error display
    PredictionErrorDisplay.from_predictions(y_test, y_pred_test, kind="actual_vs_predicted", ax=ax)
    ax.set_xlim(0, 6)
    ax.set_ylim(0, 6)
    ax.text(0.05, 0.95, score_label,
         transform=ax.transAxes,
         fontsize=12,
         va='top',
         bbox=dict(facecolor='white', alpha=0.5, boxstyle='round,pad=0.5'))
    ax.set_title(f"Prediction Error Plot - {name}")

# Hide grid cells that have no transformer assigned to them.
for unused_ax in axes.flat[len(transformers):]:
    unused_ax.set_visible(False)

plt.show()
